1ddecf
@@ -173,7 +173,7 @@
public TrimResult trimFields(
       }
 
       Set<RelDataTypeField> inputExtraFields =
-              Collections.<RelDataTypeField>emptySet();
+          Collections.<RelDataTypeField>emptySet();
       TrimResult trimResult =
           trimChild(join, input, inputFieldsUsed.build(), inputExtraFields);
       newInputs.add(trimResult.left);
@@ -212,7 +212,7 @@
public TrimResult trimFields(
 
     // Build new join.
     final RexVisitor<RexNode> shuttle = new RexPermuteInputsShuttle(
-            mapping, newInputs.toArray(new RelNode[newInputs.size()]));
+        mapping, newInputs.toArray(new RelNode[newInputs.size()]));
     RexNode newConditionExpr = conditionExpr.accept(shuttle);
 
     List<RexNode> newJoinFilters = Lists.newArrayList();
@@ -222,14 +222,14 @@
public TrimResult trimFields(
     }
 
     final RelDataType newRowType = RelOptUtil.permute(join.getCluster().getTypeFactory(),
-            join.getRowType(), mapping);
+        join.getRowType(), mapping);
     final RelNode newJoin = new HiveMultiJoin(join.getCluster(),
-            newInputs,
-            newConditionExpr,
-            newRowType,
-            join.getJoinInputs(),
-            join.getJoinTypes(),
-            newJoinFilters);
+        newInputs,
+        newConditionExpr,
+        newRowType,
+        join.getJoinInputs(),
+        join.getJoinTypes(),
+        newJoinFilters);
 
     return new TrimResult(newJoin, mapping);
   }
@@ -271,7 +271,7 @@
public TrimResult trimFields(DruidQuery dq, ImmutableBitSet fieldsUsed,
   }
 
   private static RelNode project(DruidQuery dq, ImmutableBitSet fieldsUsed,
-          Set<RelDataTypeField> extraFields, RelBuilder relBuilder) {
+      Set<RelDataTypeField> extraFields, RelBuilder relBuilder) {
     final int fieldCount = dq.getRowType().getFieldCount();
     if (fieldsUsed.equals(ImmutableBitSet.range(fieldCount))
         && extraFields.isEmpty()) {
@@ -321,7 +321,7 @@
private boolean isRexLiteral(final RexNode rexNode) {
   // because if not and it consist of keys (unique + not null OR pk), we can safely remove rest of the columns
   // if those are columns are not being used further up
   private ImmutableBitSet generateGroupSetIfCardinalitySame(final Aggregate aggregate,
-                                        final ImmutableBitSet originalGroupSet, final ImmutableBitSet fieldsUsed) {
+      final ImmutableBitSet originalGroupSet, final ImmutableBitSet fieldsUsed) {
 
     RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
     RelMetadataQuery mq = aggregate.getCluster().getMetadataQuery();
@@ -473,7 +473,7 @@
private ImmutableBitSet generateNewGroupset(Aggregate aggregate, ImmutableBitSet
    *
    */
   private Aggregate rewriteGBConstantKeys(Aggregate aggregate, ImmutableBitSet fieldsUsed,
-                                          Set<RelDataTypeField> extraFields) {
+      ImmutableBitSet aggCallFields) {
     if ((aggregate.getIndicatorCount() > 0)
         || (aggregate.getGroupSet().isEmpty())
         || fieldsUsed.contains(aggregate.getGroupSet())) {
@@ -503,7 +503,7 @@
private Aggregate rewriteGBConstantKeys(Aggregate aggregate, ImmutableBitSet fie
 
     if (allConstants) {
       for (int i = 0; i < rowType.getFieldCount(); i++) {
-        if (aggregate.getGroupSet().get(i)) {
+        if (aggregate.getGroupSet().get(i) && !aggCallFields.get(i)) {
           newProjects.add(rexBuilder.makeLiteral(true));
         } else {
           newProjects.add(rexBuilder.makeInputRef(input, i));
@@ -512,7 +512,7 @@
private Aggregate rewriteGBConstantKeys(Aggregate aggregate, ImmutableBitSet fie
       relBuilder.push(input);
       relBuilder.project(newProjects);
       Aggregate newAggregate = new HiveAggregate(aggregate.getCluster(), aggregate.getTraitSet(), relBuilder.build(),
-                                                 aggregate.getGroupSet(), null, aggregate.getAggCallList());
+          aggregate.getGroupSet(), null, aggregate.getAggCallList());
       return newAggregate;
     }
     return aggregate;
@@ -533,24 +533,33 @@
public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Se
     //
     // But group and indicator fields stay, even if they are not used.
 
-    aggregate = rewriteGBConstantKeys(aggregate, fieldsUsed, extraFields);
+    // Compute which input fields are used.
 
-    final RelDataType rowType = aggregate.getRowType();
 
-    // Compute which input fields are used.
-    // 1. group fields are always used
-    final ImmutableBitSet.Builder inputFieldsUsed =
-        aggregate.getGroupSet().rebuild();
-    // 2. agg functions
+    // agg functions
+    // Fields used by agg calls are collected first (before the group set) because
+    // rewriteGBConstantKeys needs them.
+    final ImmutableBitSet.Builder aggCallFieldsUsedBuilder =  ImmutableBitSet.builder();
     for (AggregateCall aggCall : aggregate.getAggCallList()) {
       for (int i : aggCall.getArgList()) {
-        inputFieldsUsed.set(i);
+        aggCallFieldsUsedBuilder.set(i);
       }
       if (aggCall.filterArg >= 0) {
-        inputFieldsUsed.set(aggCall.filterArg);
+        aggCallFieldsUsedBuilder.set(aggCall.filterArg);
       }
     }
 
+    // transform if the group by contains constant keys
+    ImmutableBitSet aggCallFieldsUsed = aggCallFieldsUsedBuilder.build();
+    aggregate = rewriteGBConstantKeys(aggregate, fieldsUsed, aggCallFieldsUsed);
+
+    // start from the group fields, then add the fields used by agg calls
+    final ImmutableBitSet.Builder inputFieldsUsed =  aggregate.getGroupSet().rebuild();
+    inputFieldsUsed.addAll(aggCallFieldsUsed);
+
+
+    final RelDataType rowType = aggregate.getRowType();
+
     // Create input with trimmed columns.
     final RelNode input = aggregate.getInput();
     final Set<RelDataTypeField> inputExtraFields = Collections.emptySet();
@@ -582,7 +591,7 @@
public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Se
     if (input == newInput
         && fieldsUsed.equals(ImmutableBitSet.range(rowType.getFieldCount()))) {
       return result(aggregate,
-                    Mappings.createIdentity(rowType.getFieldCount()));
+          Mappings.createIdentity(rowType.getFieldCount()));
     }
 
     // update the group by keys based on inputMapping
@@ -615,7 +624,7 @@
public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Se
     } else {
       newGroupSets = ImmutableList.copyOf(
           Iterables.transform(aggregate.getGroupSets(),
-            input1 -> Mappings.apply(inputMapping, input1)));
+              input1 -> Mappings.apply(inputMapping, input1)));
     }
 
     // Populate mapping of where to find the fields. System, group key and
@@ -641,8 +650,8 @@
public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Se
             : relBuilder.field(Mappings.apply(inputMapping, aggCall.filterArg));
         RelBuilder.AggCall newAggCall =
             relBuilder.aggregateCall(aggCall.getAggregation(),
-                                     aggCall.isDistinct(), aggCall.isApproximate(),
-                                     filterArg, aggCall.name, args);
+                aggCall.isDistinct(), aggCall.isApproximate(),
+                filterArg, aggCall.name, args);
         mapping.set(j, updatedGroupCount +  newAggCallList.size());
         newAggCallList.add(newAggCall);
       }
